<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>micrograd</title>
    <style>
        body, html {
            margin: 0;
            padding: 0;
            padding-bottom: 100px;
        }
        h1 {
            text-align: center;
            font-family: 'Arial', sans-serif;
            color: #333;
            padding: 5px;
            margin: 0px;
            font-size: 20px;
        }
        #canvas-div {
            margin: 5px;
            max-height: 510px;
        }
        #decision-canvas {
            border: 1px solid black;
        }
        .container {
            display: flex;
            justify-content: flex-start;
            align-items: flex-start;
            max-width: 100%;
            overflow-x: hidden;
        }
        #optimizer-div {
            margin: 5px;
            font-size: 16px;
            max-height: 500px;
            overflow-y: auto;
            border: 1px solid black;
            flex: 1;
        }
        #optimizer-div table {
            border-collapse: collapse;
        }
        #optimizer-div th, #optimizer-div td {
            border: 1px solid #ddd;
            padding: 4px 8px;
            text-align: right;
        }
        #optimizer-div th {
            background-color: #f2f2f2;
        }
        #graph-div {
            margin: 5px;
            width: calc(100% - 10px);  /* Subtract margin from width */
            border: 1px solid black;
            overflow: hidden;
        }
        #graph-div img {
            max-width: 100%;
            height: auto;
            display: block;
        }
        #controls-div {
            margin: 5px;
            padding: 20px;
            color: #333;
            font-size: 18px;
            flex: 1;
        }
        .control-btn {
            margin: 0px;
            padding: 5px 10px;
            font-size: 20px;
            cursor: pointer;
            background-color: white;
            border: 1px solid #ddd;
            border-radius: 4px;
        }
    </style>
</head>
<body>
    <h1>micrograd live demo</h1>

    <!-- The SVG that displays the computational graph of the MLP goes here -->
    <div class="container">
        <div id="graph-div">
            <object id="svg-object" type="image/svg+xml" data="graph.svg" width="100%" height="100%">
                <div>The compute graph could not be displayed: either your browser does not support SVG,</div>
                <div>or graph.svg is missing. Run `python micrograd.py` to generate the graph.svg file.</div>
            </object>
        </div>
    </div>

    <!-- Container for the control panel -->
    <div class="container">
        <div id="controls-div">
            <!-- Here we can set the datapoint of interest for the graph visualization -->
            <div>
                <span>Compute graph visualization datapoint:</span>
                <label>X: <input type="number" id="fiction-x" value="0.0" step="0.1" style="width: 50px;"></label>
                <label>Y: <input type="number" id="fiction-y" value="0.0" step="0.1" style="width: 50px;"></label>
                <label>Label: <input type="number" id="fiction-label" value="0" min="0" max="2" step="1"></label>
                <button id="apply-fiction" class="control-btn">Apply</button>
            </div>
            <!-- Here we can toggle the level set lines on and off -->
            <label>
                <input type="checkbox" id="show-level-sets" checked> Show level set lines
            </label>
            <!-- Here we can reset the demo, toggle the training loop on and off, and step through the optimization -->
            <div>
                <button id="reset-btn" class="control-btn" title="Reset">reset</button>
                <button id="toggle-btn" class="control-btn" title="Play/Pause">play/pause</button>
                <button id="step-btn" class="control-btn" title="Step">step</button>
            </div>
        </div>
    </div>

    <div class="container">
        <!-- The canvas div that displays the current decision boundary of the MLP -->
        <div id="canvas-div">
            <canvas id="decision-canvas" width="500" height="500">Browser not supported for Canvas.</canvas><br /><br />
        </div>
        <!-- The div that displays the AdamW optimizer state: param, grad, m, v -->
        <div id="optimizer-div">
        </div>
    </div>

    <!-- The micrograd JavaScript lib -->
    <script src="micrograd.js"></script>

<!-- The "int main()" of the demo-->
<script>

// ---- shared demo state, read and written by the functions below ----
// seeded random number generator, (re)created in reset()
let random;
// the generated dataset: an object with .train / .validation / .test splits
let dataSplits;
let trainSplit;
let valSplit;
let testSplit;
// the MLP being trained and the optimizer that updates its parameters
let model;
let optimizer;
// ---- optimization state ----
// number of training steps taken so far, and the step budget
let step = 0;
let numSteps = 100;
// whether the training loop is paused
let isPaused = false;
// handle returned by setTimeout for the next scheduled training step (null when none)
let animationId = null;
// the "fiction" datapoint that we trace through the compute graph:
// a one-element dataset of the form [[[x, y], label]]
let fwd_dataset = [[[new Value(0.0), new Value(0.0)], 0]];

function lossFun(model, split) {
    // Mean cross-entropy loss of `model` over a data split.
    //   model: object exposing forward(x), which returns the logits for input x
    //   split: iterable of [x, y] pairs (input, target label) with a .length
    // Returns the mean loss as a Value, so callers can run .backward() on it.
    let runningSum = new Value(0.0);
    for (const [inputs, target] of split) {
        // accumulate the per-example loss into the running sum
        runningSum = runningSum.add(crossEntropy(model.forward(inputs), target));
    }
    // normalize the sum by the number of examples
    return runningSum.mul(1.0 / split.length);
}

// reset the demo to its initial state
function reset() {
    // Rebuilds the dataset, model and optimizer from scratch, rewinds the step
    // counter and re-renders every visualization with the initial state.

    // Remember whether the previous run had already used up all of its steps.
    // In that case no further training step is scheduled any more, so without
    // an explicit restart the demo would sit frozen at step 0 after a reset
    // even though the play/pause button still signals "playing".
    const previousRunFinished = step >= numSteps;

    // Create an instance of RNG with seed 42
    random = new RNG(42);
    // Generate data using the genDataYinYang function
    dataSplits = genDataYinYang(random, 100);
    trainSplit = dataSplits.train;
    valSplit = dataSplits.validation;
    testSplit = dataSplits.test;
    // init the model: 2D inputs, 8 neurons, 3 outputs (logits)
    model = new MLP(2, [8, 3]);
    // optimize using AdamW
    // NOTE(review): arguments are presumably lr, betas, eps, weight decay — confirm against micrograd.js
    optimizer = new AdamW(model.parameters(), 1e-1, [0.9, 0.95], 1e-8, 1e-4);
    // reset the step counter
    step = 0;
    // update all the vis with the initial state
    const trainLoss = lossFun(model, trainSplit);
    optimizer.zeroGrad();
    trainLoss.backward();
    renderCanvas(); // show the top down view
    renderOptimizerState(); // show the optimizer state
    renderGraph(); // show the graph of nodes and edges

    // Restart the training loop if the previous run had finished and we are not paused.
    // Any leftover timer is cleared first so that exactly one loop is ever running.
    // The restart is delayed by one tick of the loop so the initial state stays visible.
    if (previousRunFinished && !isPaused) {
        clearTimeout(animationId);
        animationId = setTimeout(() => trainAndRenderStep(true), 200);
    }
}
reset();
// when the reset button is clicked, reset the demo
document.getElementById('reset-btn').addEventListener('click', reset);

function renderCanvas(minX = -2, minY = -2, maxX = 2, maxY = 2) {
    // Redraws the top-down view on the decision canvas: the model's decision
    // surface, the training points, (optionally) the zero-level set of every
    // first-layer neuron, the step counter and the "fiction" datapoint.
    //   minX, minY, maxX, maxY: the data-space rectangle that is mapped onto the canvas
    const canvas = document.getElementById('decision-canvas');
    const ctx = canvas.getContext('2d');

    // Clear the canvas
    ctx.clearRect(0, 0, canvas.width, canvas.height);

    // Map data coordinates to canvas pixels (canvas y grows downwards, hence the flip)
    function mapToCanvas(x, y) {
        const canvasX = (x - minX) * (canvas.width / (maxX - minX));
        const canvasY = (maxY - y) * (canvas.height / (maxY - minY));
        return [canvasX, canvasY];
    }

    // Render the current decision surface as a grid of colored cells.
    // The grid is walked with integer indices so that floating point drift from
    // repeatedly adding 0.1 can neither add nor drop a row/column of cells.
    const stepSize = 0.1;
    const numCols = Math.ceil((maxX - minX) / stepSize - 1e-9);
    const numRows = Math.ceil((maxY - minY) / stepSize - 1e-9);
    for (let col = 0; col < numCols; col++) {
        const x = minX + col * stepSize;
        for (let row = 0; row < numRows; row++) {
            const y = minY + row * stepSize;
            // evaluate the model at the center of the cell
            const logits = model.forward([new Value(x + stepSize / 2), new Value(y + stepSize / 2)]);
            // softmax over the logits; subtracting the max first keeps Math.exp
            // from overflowing to Infinity (which would turn the color into NaN)
            const logitData = logits.map(logit => logit.data);
            const maxLogit = Math.max(...logitData);
            const exps = logitData.map(value => Math.exp(value - maxLogit));
            const sumExps = exps.reduce((a, b) => a + b, 0);
            const probs = exps.map(exp => exp / sumExps);

            // class probabilities -> RGB channels, each blended 50% towards white
            const [mutedR, mutedG, mutedB] = [0, 1, 2].map(k => {
                const channel = Math.floor(probs[k] * 255);
                return Math.floor(channel + (255 - channel) * 0.5);
            });

            const [canvasX, canvasY] = mapToCanvas(x, y);
            const [canvasX2, canvasY2] = mapToCanvas(x + stepSize, y + stepSize);
            const width = canvasX2 - canvasX;
            const height = canvasY - canvasY2;
            ctx.fillStyle = `rgb(${mutedR},${mutedG},${mutedB})`;
            ctx.strokeStyle = `rgb(${mutedR},${mutedG},${mutedB})`;
            ctx.fillRect(canvasX, canvasY2, width, height);
            // outline in the same color (presumably to hide seams between neighboring cells)
            ctx.strokeRect(canvasX, canvasY2, width, height);
        }
    }

    // Render training data points, colored by their label
    for (const [x, y] of trainSplit) {
        const [canvasX, canvasY] = mapToCanvas(x[0], x[1]);
        ctx.fillStyle = y === 0 ? 'red' : y === 1 ? 'green' : 'blue';
        ctx.beginPath();
        ctx.arc(canvasX, canvasY, 5, 0, 2 * Math.PI);
        ctx.fill();
        ctx.strokeStyle = 'black';
        ctx.stroke();
    }

    // Render the 0-level set of all individual neurons of the first layer,
    // i.e. the line w0 * x + w1 * y + b = 0, spanning the visible x-range
    const showLevelSets = document.getElementById('show-level-sets').checked;
    if (showLevelSets) {
        for (const neuron of model.layers[0].neurons) {
            const w0 = neuron.w[0].data;
            const w1 = neuron.w[1].data;
            const bias = neuron.b.data;
            let from, to;
            if (w1 !== 0) {
                // solve for y at the left and right edge of the view
                from = mapToCanvas(minX, (-bias - w0 * minX) / w1);
                to = mapToCanvas(maxX, (-bias - w0 * maxX) / w1);
            } else if (w0 !== 0) {
                // the line is exactly vertical: x = -b / w0
                const xVertical = -bias / w0;
                from = mapToCanvas(xVertical, minY);
                to = mapToCanvas(xVertical, maxY);
            } else {
                // both weights are zero: there is no line to draw
                continue;
            }
            ctx.strokeStyle = 'white';
            ctx.beginPath();
            ctx.moveTo(from[0], from[1]);
            ctx.lineTo(to[0], to[1]);
            ctx.stroke();
        }
    }

    // Top right of the canvas render the step counter
    ctx.fillStyle = 'white';
    ctx.font = '20px Arial';
    ctx.fillText(`step ${step} / ${numSteps}`, canvas.width - 120, 20);

    // Render the fiction datapoint
    const [fictionX, fictionY] = mapToCanvas(fwd_dataset[0][0][0].data, fwd_dataset[0][0][1].data);
    ctx.fillStyle = 'yellow';
    ctx.beginPath();
    ctx.arc(fictionX, fictionY, 5, 0, 2 * Math.PI);
    ctx.fill();
    ctx.strokeStyle = 'black';
    ctx.stroke();
}

function renderOptimizerState(num_columns = 3) {
    // Rebuilds the optimizer panel: the model parameters are laid out column-major
    // over `num_columns` side-by-side tables, one row per parameter showing
    // param, -m/sqrt(v), grad, m and sqrt(v).
    const panel = document.getElementById('optimizer-div');
    panel.innerHTML = '';

    const params = model.parameters();
    const rowsPerTable = Math.ceil(params.length / num_columns);
    // green for non-negative numbers, red for everything else
    const signColor = (value) => (value >= 0 ? '#45a049' : '#e06666');

    for (let col = 0; col < num_columns; col++) {
        const table = document.createElement('table');
        table.innerHTML = '<tr><th>param</th><th>-m/sqrt(v)</th><th>grad</th><th>m</th><th>sqrt(v)</th></tr>';
        table.style.display = 'inline-block';
        table.style.marginRight = '10px';
        table.style.verticalAlign = 'top';

        // this table shows the parameters with index in [firstIndex, endIndex)
        const firstIndex = col * rowsPerTable;
        const endIndex = Math.min(firstIndex + rowsPerTable, params.length);
        for (let index = firstIndex; index < endIndex; index++) {
            const param = params[index];
            // "lookahead" first and second moment: what m and v would become if the
            // current gradient were folded in (no bias correction applied)
            const m = optimizer.beta1 * param.m + (1 - optimizer.beta1) * param.grad;
            const v = optimizer.beta2 * param.v + (1 - optimizer.beta2) * (param.grad ** 2);
            const sqrtV = Math.sqrt(v + 1e-8);
            // negated because the parameter update is applied with -=
            const update = -m / sqrtV;
            const tableRow = document.createElement('tr');
            tableRow.innerHTML = `
                <td style="color: ${signColor(param.data)}">${param.data.toFixed(4)}</td>
                <td style="color: ${signColor(update)}">${update.toFixed(4)}</td>
                <td style="color: ${signColor(param.grad)}">${param.grad.toFixed(4)}</td>
                <td style="color: ${signColor(param.m)}">${param.m.toFixed(4)}</td>
                <td style="color: ${signColor(sqrtV)}">${sqrtV.toFixed(4)}</td>
            `;
            table.appendChild(tableRow);
        }

        panel.appendChild(table);
    }
}

function trace(root) {
    // Traces the full graph of nodes and edges reachable from `root`.
    //   root: a node whose `_prev` is an iterable of its child (input) nodes
    // Returns [nodes, edges]:
    //   nodes: every reachable node exactly once, in depth-first pre-order
    //   edges: [child, parent] pairs, each distinct pair exactly once
    // Sets are used for the "seen" bookkeeping so the walk stays linear in the
    // size of the graph instead of rescanning the arrays for every node/edge.
    const nodes = [];
    const edges = [];
    const visited = new Set();

    function build(v) {
        if (visited.has(v)) return;
        visited.add(v);
        nodes.push(v);
        // Edges into v are only ever added here, and each node is expanded once,
        // so de-duplicating per parent is equivalent to de-duplicating globally.
        const linkedChildren = new Set();
        for (const child of v._prev) {
            if (!linkedChildren.has(child)) {
                linkedChildren.add(child);
                edges.push([child, v]);
            }
            build(child);
        }
    }

    build(root);
    return [nodes, edges];
}

function getGraphNodes() {
    // Returns the list of all "data" nodes of the graph SVG (in document order).
    // Returns an empty list when the SVG document is not available, e.g. because
    // the <object> has not finished loading yet or graph.svg could not be loaded;
    // in that case contentDocument is null and must not be dereferenced.
    const svgObject = document.getElementById('svg-object');
    const svg = svgObject ? svgObject.contentDocument : null;
    if (!svg) {
        return [];
    }
    // issue now is that there are two types of nodes:
    // 1) the "operation" nodes for e.g. + / - / * ...
    // 2) the actual "data" nodes. We can identify these by presence of "text" element
    //    that mentions "data", for example. A bit janky but it works.
    return Array.from(svg.querySelectorAll('g.node')).filter((node) => {
        const text = node.querySelector('text');
        return Boolean(text) && text.textContent.includes("data");
    });
}

function renderGraph() {
    // Forwards/backwards the "fiction" datapoint through the model and writes the
    // resulting data/grad of every graph node into the matching node of the SVG.

    // Step 0: start from zero gradients. The parameters still hold the gradients of
    // the last training batch, and the fiction inputs are reused between calls, so
    // without this the displayed gradients would pile up from render to render.
    // NOTE(review): assumes Value.backward() accumulates into .grad (which the
    // zeroGrad() before every training backward pass suggests) — confirm in micrograd.js
    optimizer.zeroGrad();
    for (const [inputs] of fwd_dataset) {
        for (const input of inputs) {
            input.grad = 0;
        }
    }
    // Step 1: forward some point of interest e.g. the origin (0,0), e.g. label 0:
    const trainLoss = lossFun(model, fwd_dataset);
    trainLoss.backward();
    // Step 2: let's "walk" the graph and collect all the nodes and edges
    const [nodes, edges] = trace(trainLoss);
    // Step 3: get the list of all svg "g" elements of class "node"
    const svg_nodes = getGraphNodes();
    // Error checking: we should have as many nodes as there are svg nodes
    if (svg_nodes.length !== nodes.length) {
        console.log("found a total of", svg_nodes.length, "svg nodes");
        console.log("found a total of", nodes.length, "nodes and", edges.length, "edges");
        console.log("ERROR: found a different number of nodes than expected");
        return;
    }
    // Step 4: write data/grad into the svg; the i-th traced node is paired with the
    // i-th svg data node, i.e. both lists are expected to be in the same order
    for (let i = 0; i < svg_nodes.length; i++) {
        // there should be exactly two "text" nodes in the svg node
        const svg_node = svg_nodes[i];
        const textNodes = svg_node.querySelectorAll('text');
        if (textNodes.length !== 2) {
            console.log("ERROR: found a different number of text nodes than expected");
            return;
        }
        const graph_node = nodes[i];
        // the first text node is the data, second node is the grad
        textNodes[0].textContent = "data: " + graph_node.data.toFixed(4);
        textNodes[1].textContent = "grad: " + graph_node.grad.toFixed(4);
    }
}

function trainAndRenderStep(schedule_next_step = true) {
    // Runs one optimization step and refreshes all visualizations.
    //   schedule_next_step: when true, queues the following step 200ms from now
    // Does nothing once the step budget (numSteps) is used up.
    if (step >= numSteps) {
        return;
    }
    step += 1;

    // forward pass: loss on the training split
    const trainLoss = lossFun(model, trainSplit);
    console.log(`step ${step}, train loss ${trainLoss.data.toFixed(6)}`);

    // backward pass from zeroed gradients, followed by the parameter update
    optimizer.zeroGrad();
    trainLoss.backward();
    optimizer.step();

    // refresh the three views (data + grad)
    renderCanvas(); // top down view
    renderOptimizerState(); // optimizer state
    renderGraph(); // graph of nodes and edges

    if (!schedule_next_step) {
        return;
    }
    // keep the timer handle so that pausing can cancel the pending step
    animationId = setTimeout(() => trainAndRenderStep(true), 200);
}

function applyFictionDataPoint() {
    // Reads the X / Y / Label inputs and, if they are valid, makes them the new
    // "fiction" datapoint and re-renders the graph and the canvas.
    // Invalid input (an empty or non-numeric field, or a label outside 0..2) is
    // ignored with a console warning, so the visualizations never get fed a NaN
    // coordinate or a label the 3-class model has no logit for.
    const labelInput = document.getElementById('fiction-label');
    const x = parseFloat(document.getElementById('fiction-x').value);
    const y = parseFloat(document.getElementById('fiction-y').value);
    const label = parseInt(labelInput.value, 10);
    if (!Number.isFinite(x) || !Number.isFinite(y)) {
        console.warn('ignoring fiction datapoint: X and Y must be numbers');
        return;
    }
    // the valid labels 0, 1, 2 match the min/max of the label input and the colors below
    if (!Number.isInteger(label) || label < 0 || label > 2) {
        console.warn('ignoring fiction datapoint: label must be 0, 1 or 2');
        return;
    }
    // update the fiction datapoint
    fwd_dataset = [[[new Value(x), new Value(y)], label]];
    // we need to re-render the graph and the canvas
    renderGraph();
    renderCanvas();
    // and color the label
    const color = label === 0 ? 'red' : label === 1 ? 'green' : 'blue';
    labelInput.style.color = color;
}
document.getElementById('apply-fiction').addEventListener('click', applyFictionDataPoint);

// toggle play/pause button
function toggleTraining() {
    // Flips between paused and running, and colors the toggle button to match:
    // red while paused, green while running.
    const toggleButton = document.getElementById('toggle-btn');
    isPaused = !isPaused;
    if (!isPaused) {
        // resuming: take a step right away, which also schedules the following ones
        trainAndRenderStep(true);
        toggleButton.style.backgroundColor = 'lightgreen';
        return;
    }
    // pausing: cancel the step that is currently scheduled
    clearTimeout(animationId);
    animationId = null;
    toggleButton.style.backgroundColor = 'lightcoral';
}
document.getElementById('toggle-btn').addEventListener('click', toggleTraining);

// run a single step of the optimization (only if paused)
function stepTraining() {
    // Manual single-stepping is only allowed while the training loop is paused.
    if (!isPaused) {
        return;
    }
    // take exactly one step without scheduling a follow-up
    trainAndRenderStep(false);
}
document.getElementById('step-btn').addEventListener('click', stepTraining);

// Add event listener for "show level sets" checkbox changes
// Note: renderCanvas is wrapped rather than passed directly, so that the event
// object does not end up as its first (minX) argument.
document.getElementById('show-level-sets').addEventListener('change', function () {
    renderCanvas();
});

// Give the SVG ability to pan/zoom
function enablePanZoom() {
    // Wires mouse / touch / wheel handlers onto the root element of the embedded
    // SVG so that the graph can be dragged around and zoomed by changing its viewBox.
    // Logs an error and does nothing if the SVG document is not available.
    const svgObject = document.getElementById('svg-object');
    // contentDocument is null when the <object> has no loaded (or accessible) document
    const svgDocument = svgObject ? svgObject.contentDocument : null;
    const svg = svgDocument ? svgDocument.documentElement : null;
    if (!svg || !svg.viewBox) {
        console.error('SVG not found');
        return;
    }

    const viewBox = svg.viewBox.baseVal;
    let isPanning = false;
    // the point (in SVG user coordinates) that was grabbed when the pan started
    let panAnchor = { x: 0, y: 0 };

    svg.addEventListener('mousedown', startPan);
    svg.addEventListener('mousemove', pan);
    svg.addEventListener('mouseup', endPan);
    svg.addEventListener('mouseleave', endPan);
    svg.addEventListener('wheel', zoom, { passive: false });

    // Touch events
    svg.addEventListener('touchstart', startPan, { passive: false });
    svg.addEventListener('touchmove', pan, { passive: false });
    svg.addEventListener('touchend', endPan);
    svg.addEventListener('touchcancel', endPan);

    // Converts the position of a mouse or touch event into SVG user coordinates,
    // using the current screen transform (which reflects the current viewBox).
    function getPointFromEvent(event) {
        // Pick the first touch if there is one. Do not test clientX for truthiness:
        // a pointer at coordinate 0 is perfectly valid.
        const source = (event.touches && event.touches.length > 0) ? event.touches[0] : event;
        const point = svg.createSVGPoint();
        point.x = source.clientX;
        point.y = source.clientY;
        return point.matrixTransform(svg.getScreenCTM().inverse());
    }

    function startPan(event) {
        event.preventDefault();
        isPanning = true;
        panAnchor = getPointFromEvent(event);
    }

    function pan(event) {
        if (!isPanning) return;
        event.preventDefault();
        // Shift the viewBox so that the grabbed point ends up under the pointer again.
        // Both points are already in user coordinates, so no extra scaling by the zoom
        // level is needed, and because the anchor is expressed in user coordinates it
        // stays valid after the viewBox has moved.
        const current = getPointFromEvent(event);
        viewBox.x += panAnchor.x - current.x;
        viewBox.y += panAnchor.y - current.y;
        svg.setAttribute('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.width} ${viewBox.height}`);
    }

    function endPan() {
        isPanning = false;
    }

    function zoom(event) {
        event.preventDefault();
        // purely horizontal scrolling carries no zoom direction
        if (event.deltaY === 0) return;
        const zoomPoint = getPointFromEvent(event);
        // scrolling down widens the viewBox (zoom out), scrolling up narrows it (zoom in)
        const zoomFactor = event.deltaY > 0 ? 1.1 : 0.9;

        viewBox.width *= zoomFactor;
        viewBox.height *= zoomFactor;

        // move the origin so that the point under the pointer stays where it is
        viewBox.x += (zoomPoint.x - viewBox.x) * (1 - zoomFactor);
        viewBox.y += (zoomPoint.y - viewBox.y) * (1 - zoomFactor);

        svg.setAttribute('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.width} ${viewBox.height}`);
    }
}
document.getElementById('svg-object').addEventListener('load', enablePanZoom);

// Start the training loop
{
    // kick off the first step, which keeps scheduling the following ones
    trainAndRenderStep(true);
    // the demo starts out running, so the toggle button starts out green
    const toggleButton = document.getElementById('toggle-btn');
    toggleButton.style.backgroundColor = 'lightgreen';
    // at init the fiction datapoint has label 0, which is shown in red
    const fictionLabelInput = document.getElementById('fiction-label');
    fictionLabelInput.style.color = 'red';
}

</script>
</body>
</html>